from contextlib import contextmanager
import numpy as np
import torch
import matplotlib.pyplot as plt
import torch.nn as nn

from core.components.base_model import BaseModel
from core.modules.losses import L2Loss
from core.modules.subnetworks import Encoder, Decoder, Predictor
from core.modules.recurrent_modules import RecurrentPredictor
from core.utils.general_utils import AttrDict, ParamDict, batch_apply
from core.utils.pytorch_utils import pad_seq
from core.modules.layers import LayerBuilderParams
from core.utils.vis_utils import make_gif_strip, make_image_seq_strip, make_image_strip
from core.data.bouncing_sprites.utils.template_blender import TemplateBlender
from core.data.bouncing_sprites.moving_sprites import NaturalDistractorMPPIMovingSpritesGenerator 

class StateModel(BaseModel):
    """Encoder with an image-reconstruction decoder and a state-regression head.

    The encoder output of a single frame is fed to (a) a decoder that
    reconstructs the image and (b) a predictor that regresses a state vector
    of size `output_size`.
    """

    def __init__(self, params, logger):
        """Set up hyperparameters, optional visualization helpers and the networks.

        Args:
            params: hyperparameter overrides applied on top of `_default_hparams()`.
            logger: logger object handed to the base model.
        """
        super().__init__(logger)
        self._hp = self._default_hparams()
        self._hp.overwrite(params)
        # NOTE(review): `normalization` is not among the defaults declared in this
        # class -- presumably it comes from the parent defaults or `params`; confirm.
        self._hp.builder = LayerBuilderParams(self._hp.use_convs, self._hp.normalization)

        # helpers that are only needed to render states as images in `log_outputs`
        if self._hp.state_vis:
            self._template_blender = TemplateBlender((self._hp.img_sz, self._hp.img_sz))
            self._generator = NaturalDistractorMPPIMovingSpritesGenerator(self._hp)
        self.build_network()

    @contextmanager
    def val_mode(self):
        """Context manager for validation; this model changes nothing on enter/exit."""
        yield

    def _default_hparams(self):
        """Return the parent's default hyperparameters overwritten with this model's defaults."""
        default_dict = ParamDict({
            'use_skips': True,
            'skips_stride': 2,
            'add_weighted_pixel_copy': False,  # if True, adds pixel copying stream for decoder
            'pixel_shift_decoder': False,
            'use_convs': True,
            'detach_reconstruction': True,
            'detach_state_head': False,
            'n_cond_frames': 1,
            'state_vis': False
        })

        # Network size
        default_dict.update({
            'img_sz': 32,
            'input_nc': 3,
            'ngf': 8,
            'nz_enc': 32,
            'nz_mid': 32,
            'output_size': 4,
            'n_processing_layers': 3,
            'n_pixel_sources': 1,
        })

        # Loss weights
        default_dict.update({
            'img_mse_weight': 1.,
            'state_weight': 1.,
        })

        # add new params to parent params
        parent_params = super()._default_hparams()
        parent_params.overwrite(default_dict)
        return parent_params

    def build_network(self):
        """Instantiate the encoder, the image decoder and the state prediction head."""
        self.encoder = Encoder(self._hp)
        self.decoder = Decoder(self._hp)
        self.state_decoder = Predictor(self._hp, input_size=self._hp.nz_enc,
                                       output_size=self._hp.output_size, spatial=False)

    @staticmethod
    def _to_spatial(tensor):
        """Append two singleton trailing dims to a 2-D tensor; return any other tensor as is."""
        if len(tensor.shape) == 2:
            return tensor[..., None, None]
        return tensor

    def forward(self, inputs):
        """
        forward pass at training time

        Args:
            inputs: object with an `images` tensor; only index 0 along dim 1 is encoded.

        Returns:
            AttrDict with
              `pred`: the encoder output,
              `output_imgs`: the decoded image with a singleton dim inserted at position 1,
              `states`: the output of the state head.
        """
        output = AttrDict()

        # encode inputs
        enc, skips = self.encoder(inputs.images[:, 0])
        # only use skips from last conditioning frame; they are detached so that no
        # gradient flows back into the encoder through the skip connections
        skips = [s if s is None else s.detach() for s in skips]
        output.pred = enc

        # image reconstruction (optionally cut off from the encoder's gradient)
        rec_input = self._to_spatial(enc.detach() if self._hp.detach_reconstruction else enc)
        output.output_imgs = self.decoder(rec_input, skips=skips).images.unsqueeze(1)

        # state decoding (optionally cut off from the encoder's gradient)
        state_input = self._to_spatial(enc.detach() if self._hp.detach_state_head else enc)
        output.states = self.state_decoder(state_input)

        return output

    def forward_head(self, inputs):
        """Encode `inputs` and return the state head's prediction.

        The encoding is brought into the same layout that `forward` feeds to the
        state head, so both paths treat 2-D encodings identically.
        """
        enc, _ = self.encoder(inputs)
        return self.state_decoder(self._to_spatial(enc))

    def forward_encoder(self, inputs):
        """Encode `inputs` and return only the encoding (skip features are dropped)."""
        enc, _ = self.encoder(inputs)
        return enc

    def loss(self, model_output, inputs):
        """Compute the reconstruction loss, the state regression loss and their total.

        Args:
            model_output: output of `forward` (uses `output_imgs` and `states`).
            inputs: object with `images` and `states`; `states` is squeezed along
                dim 1 before being compared to the prediction.

        Returns:
            AttrDict with entries `seq_img_mse`, `state` and `total`.
        """
        losses = AttrDict()

        # image reconstruction loss
        losses.seq_img_mse = L2Loss(self._hp.img_mse_weight)(model_output.output_imgs,
                                                             inputs.images)

        # state regression loss
        losses.state = L2Loss(self._hp.state_weight)(model_output.states,
                                                     inputs.states.squeeze(1))

        losses.total = self._compute_total_loss(losses)
        return losses

    def log_outputs(self, model_output, inputs, losses, step, log_images, phase):
        """Log losses and, if `log_images` is set, images and the state-loss breakdown.

        Args:
            model_output: output of `forward`.
            inputs: the batch that produced `model_output`.
            losses: output of `loss`.
            step: step value forwarded to the logger.
            log_images: if falsy, only the losses are logged.
            phase: phase identifier forwarded to the logger.
        """
        super()._log_losses(losses, step, log_images, phase)
        if not log_images:
            return

        # log input images next to their reconstructions
        img_strip = make_image_strip([inputs.images.squeeze(1), model_output.output_imgs.squeeze(1)])
        self._logger.log_images(img_strip[None], 'generation', step, phase)

        if self._hp.state_vis:
            shape_names = np.array(self._generator.SHAPES)[self._generator._sample_shapes()]
            sprites = [self._generator._shape_sprites[shape] for shape in shape_names]
            scale = self._hp.img_sz - 1

            # render the first sample of the batch; its state is reshaped to (1, 2, 2)
            # and scaled by (img_sz - 1) before being truncated to int
            pred_coords = model_output.states[0].detach().cpu().numpy().reshape(1, 2, 2)
            pred_state = self._template_blender.create((pred_coords * scale).astype(int), sprites)
            self._logger.log_images(pred_state[None], 'pred_state', step, phase)

            gt_coords = inputs.states.detach().cpu().numpy()[0].reshape(1, 2, 2)
            gt_state = self._template_blender.create((gt_coords * scale).astype(int), sprites)
            self._logger.log_images(gt_state[None], 'gt_state', step, phase)

        # squeeze the target exactly like `loss` does so that prediction and target
        # are compared element-wise instead of being broadcast against each other
        loss_per_step = L2Loss()(model_output.states, inputs.states.squeeze(1))
        self._logger.log_scalar(loss_per_step, 'state_loss_breakdown', step, phase)

    @property
    def resolution(self):
        """Image size taken from the `img_sz` hyperparameter."""
        return self._hp.img_sz


